using LinearAlgebra
using HDF5
using Plots


"""
    potential_array_from_h5(fname::String)

Read the dataset named "Dataset1" from the HDF5 file `fname` and return it with
its second and fourth dimensions swapped (`permutedims(_, (1, 4, 3, 2))`).

The file is opened read-only inside a `do` block so the handle is closed even
when reading the dataset throws.

NOTE(review): the dimension swap presumably undoes the axis order used by the
program that wrote the file -- confirm against the writer.
"""
function potential_array_from_h5(fname::String)::Array{Float64,4}
    raw_array = h5open(fname, "r") do hfile
        read(hfile["Dataset1"])
    end

    return permutedims(raw_array, (1, 4, 3, 2))
end


"""
    compute_step_size_arr(potential_array)

Return the grid spacing along x, y and z as a 3-element vector.

Rows 1, 2 and 3 of the first dimension are read as the x, y and z coordinates;
dimensions 2, 3 and 4 are the grid indices along x, y and z. The spacing along
an axis is the coordinate at the second grid node minus the coordinate at the
first one. The stored coordinates are in nm and the result is scaled by 1e-9.
"""
function compute_step_size_arr(potential_array::Array{Float64, 4})::Vector{Float64}
    first_node = CartesianIndex(1, 1, 1)

    steps_nm = map(1:3) do axis
        # Move one node along `axis` only; the other two grid indices stay at 1.
        next_node = CartesianIndex(ntuple(d -> d == axis ? 2 : 1, 3))
        potential_array[axis, next_node] - potential_array[axis, first_node]
    end

    # Convert from nm
    return 1e-9 * steps_nm
end


"""
    potential_gridvalues(potential_array)

Return the potential on the 3-D grid, taken from row 4 of the first dimension of
`potential_array`.

The stored values are in MHz. They are multiplied by 1e6, subtracted from 4e6
(which flips their sign) and scaled by `2pi*hbar`, where `hbar` is a global
defined outside this function.

NOTE(review): the meaning of the 4e6 offset is not visible here -- confirm.
"""
function potential_gridvalues(potential_array::Array{Float64, 4})::Array{Float64, 3}
    stored_mhz = potential_array[4, :, :, :]
    scale = 2pi*hbar

    # Fix sign and convert from MHz
    return scale * (4e6 .- 1e6 * stored_mhz)
end


"""
    create_idx_pos_dict(potential_array)

Build the two-way mapping between a linear grid index and a position vector.

Rows 1:3 of the first dimension of `potential_array` are the coordinates (in nm,
scaled here by 1e-9). The three grid dimensions are reversed before flattening,
so the last grid dimension of `potential_array` varies fastest in the linear
index.

Returns `(idx_to_position, position_to_idx)`: a vector of 3-element position
vectors and a dictionary mapping each position vector back to its index. If two
grid points share a position, the dictionary keeps the larger index.
"""
function create_idx_pos_dict(potential_array::Array{Float64, 4})::Tuple{Vector{Vector{Float64}}, Dict{Vector{Float64}, Int64}}
    coordinates_nm = permutedims(potential_array[1:3, :, :, :], (1, 4, 3, 2))
    # Convert from nm; one column per grid point
    coordinates = 1e-9 * reshape(coordinates_nm, 3, :)

    idx_to_position = [collect(column) for column in eachcol(coordinates)]

    position_to_idx = Dict{Vector{Float64}, Int64}()
    for (idx, position) in enumerate(idx_to_position)
        position_to_idx[position] = idx
    end

    return idx_to_position, position_to_idx
end


# Below this line, efficiency might matter!

"""
    _dipole_interactions_inner_integral(state_N0, state_N1, tweezer_separation_vector,
                                        polarization_vector, mol1_position_vector,
                                        idx_to_position)

Sum the angular dipole factor over the grid of the second molecule for a fixed
position of the first molecule.

For every grid point `delta_vector` in `idx_to_position` the distance vector is
`tweezer_separation_vector - mol1_position_vector + delta_vector`, and the term
added is `(1 - 3cos(theta)^2) / r^3 * state_N0[ii] * state_N1[ii]`, where `theta`
is the angle between the distance vector and `polarization_vector` and `r` is the
length of the distance vector.

# Arguments
- `state_N0`: phi_alpha' sampled on the grid (lvl0 is associated with potential 0).
- `state_N1`: tilde phi_beta' sampled on the grid.
- `polarization_vector`: must have unit norm (checked up to floating-point
  tolerance; an `AssertionError` is raised otherwise).

Returns a scalar; an empty grid gives `0.0`. A grid point with zero distance
yields a non-finite term.
"""
function _dipole_interactions_inner_integral(
    state_N0::Vector{Float64}, # phi_alpha'
    state_N1::Vector{Float64}, # tilde phi_beta'
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    mol1_position_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}},
)
    # A normalised vector such as v / norm(v) can miss 1.0 by one ulp, so the
    # comparison has to be approximate rather than `==`.
    @assert isapprox(norm(polarization_vector), 1) "polarization_vector must have unit norm"

    distance_vector = similar(tweezer_separation_vector) # reused for every grid point
    integral = 0.0

    for (ii, delta_vector) in enumerate(idx_to_position)
        overlap = state_N0[ii] * state_N1[ii]
        distance_vector .= tweezer_separation_vector .- mol1_position_vector .+ delta_vector
        distance_norm = norm(distance_vector)
        cos_theta = dot(distance_vector, polarization_vector) / distance_norm

        integral += (1 - 3cos_theta^2) / distance_norm^3 * overlap
    end

    return integral
end


"""
    dipole_interactions_full_integral(state_N0, state_N0p, state_N1, state_N1p,
                                      tweezer_separation_vector, polarization_vector,
                                      idx_to_position)

Evaluate the double sum over the grids of both molecules for one combination of
four motional states and return a scalar.

For each grid point of the first molecule, the product `state_N0[ii] * state_N1[ii]`
is multiplied by `_dipole_interactions_inner_integral` evaluated with `state_N0p`
and `state_N1p` at that point. The sum is scaled by
`1/(2pi*hbar) * (1/3*d_NaCs^2)/(4pi*eps_0)`, where `hbar`, `d_NaCs` and `eps_0`
are globals defined outside this function.

# Arguments
- `state_N0`: phi_beta, `state_N0p`: phi_alpha'.
- `state_N1`: tilde phi_alpha, `state_N1p`: tilde phi_beta'.
"""
function dipole_interactions_full_integral(
    state_N0::Vector{Float64}, # phi_beta
    state_N0p::Vector{Float64}, # phi_alpha'
    state_N1::Vector{Float64}, # tilde phi_alpha
    state_N1p::Vector{Float64}, # tilde phi_beta'
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}}
)
    # 1/3 for plus-minus Hamiltonian
    prefactor = 1/(2pi*hbar) * (1/3*d_NaCs^2)/(4pi*eps_0)

    # Contribution of one grid point of the first molecule.
    function site_contribution((ii, mol1_position_vector))
        weight = state_N0[ii] * state_N1[ii]
        partner_sum = _dipole_interactions_inner_integral(
            state_N0p, state_N1p,
            tweezer_separation_vector, polarization_vector, mol1_position_vector,
            idx_to_position
        )
        return weight * partner_sum
    end

    return prefactor * mapreduce(site_contribution, +, enumerate(idx_to_position))
end


"""
    dipole_interaction_matrix(state_N0_arr, state_N1_arr, levels,
                              tweezer_separation_vector, polarization_vector,
                              idx_to_position)

Evaluate `dipole_interactions_full_integral` for every combination of four
levels drawn from `levels` and return the results as a 4-D array.

The array is indexed by position in `levels` in the order
`(lvl1, lvl0p, lvl0, lvl1p)`. Columns `lvl0` and `lvl0p` of `state_N0_arr` and
columns `lvl1` and `lvl1p` of `state_N1_arr` are passed as the four states.
"""
function dipole_interaction_matrix(
    state_N0_arr::Matrix{Float64},
    state_N1_arr::Matrix{Float64},
    levels::Vector{Int64},
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}},
)
    # The leftmost loop variable varies fastest, so the dimensions of the result
    # follow the order in which the loop variables are listed.
    return [
        dipole_interactions_full_integral(
            state_N0_arr[:, lvl0], state_N0_arr[:, lvl0p],
            state_N1_arr[:, lvl1], state_N1_arr[:, lvl1p],
            tweezer_separation_vector,
            polarization_vector,
            idx_to_position
        )
        for lvl1 in levels, lvl0p in levels, lvl0 in levels, lvl1p in levels
    ]
end


"""
    _dipole_interactions_inner_integral_vectorized(state_N0_arr_transpose,
                                                   state_N1_arr_transpose,
                                                   tweezer_separation_vector,
                                                   polarization_vector,
                                                   mol1_position_vector,
                                                   idx_to_position)

Sum the angular dipole factor over the grid of the second molecule for all pairs
of motional levels at once, for a fixed position of the first molecule.

Both state matrices hold one level per row and one grid point per column
(lvl0 is associated with potential 0). The result has size
`(size(state_N1_arr_transpose, 1), size(state_N0_arr_transpose, 1))`; entry
`[b, a]` is the sum over grid points `ii` of
`(1 - 3cos(theta)^2) / r^3 * state_N1_arr_transpose[b, ii] * state_N0_arr_transpose[a, ii]`,
i.e. rows index beta' and columns index alpha'.

`polarization_vector` must have unit norm (checked up to floating-point
tolerance; an `AssertionError` is raised otherwise). An empty grid gives a zero
matrix.
"""
function _dipole_interactions_inner_integral_vectorized(
    state_N0_arr_transpose::Matrix{Float64}, # phi_alpha' = state_N0_arr_transpose[alpha',:]
    state_N1_arr_transpose::Matrix{Float64}, # tilde phi_beta' = state_N1_arr_transpose[beta',:]
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    mol1_position_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}},
)
    # Approximate comparison: a normalised vector can miss 1.0 by one ulp.
    @assert isapprox(norm(polarization_vector), 1) "polarization_vector must have unit norm"

    distance_vector = similar(tweezer_separation_vector) # reused for every grid point
    integral = zeros(size(state_N1_arr_transpose, 1), size(state_N0_arr_transpose, 1))

    # Accumulate in place. Folding with `.+` over a closure that returns one
    # shared scratch buffer aliases the accumulator with that buffer, which drops
    # the first grid point and counts the second one twice.
    for (ii, delta_vector) in enumerate(idx_to_position)
        distance_vector .= tweezer_separation_vector .- mol1_position_vector .+ delta_vector
        distance_norm = norm(distance_vector)
        cos_theta = dot(distance_vector, polarization_vector) / distance_norm
        angular_factor = (1 - 3cos_theta^2) / distance_norm^3

        # first dimension: beta' (state_N1); second dimension: alpha' (state_N0)
        @views integral .+= angular_factor .*
            (state_N1_arr_transpose[:, ii] .* state_N0_arr_transpose[:, ii]')
    end

    return integral
end


"""
    dipole_interactions_full_integral_vectorized(state_N0_arr_transposed,
                                                 state_N0p_arr_transposed,
                                                 state_N1_arr_transposed,
                                                 state_N1p_arr_transposed,
                                                 tweezer_separation_vector,
                                                 polarization_vector,
                                                 idx_to_position;
                                                 return_dicts=false)

Evaluate the double grid sum of the dipole interaction for all combinations of
motional levels at once.

Each state matrix holds one level per row and one grid point per column. For
every grid point of the first molecule the outer product of the `state_N0` and
`state_N1` columns is combined, via a Kronecker product, with the matrix returned
by `_dipole_interactions_inner_integral_vectorized` for `state_N0p` and
`state_N1p`. The sum is scaled by `1/(2pi*hbar) * (1/3*d_NaCs^2)/(4pi*eps_0)`,
where `hbar`, `d_NaCs` and `eps_0` are globals defined outside this function.

# Returns
A matrix of size
`(n_N0 * n_N1p, n_N1 * n_N0p)` (`n_X` = number of rows of the matching argument).
Rows are indexed by (N0 level, N1p level) with the N0 level varying slowest;
columns by (N1 level, N0p level) with the N1 level varying slowest.

With `return_dicts=true` the tuple `(interaction, idx_to_state)` is returned,
where `idx_to_state` holds one "i_j" label per row (i = N0 level, j = N1p level).

An empty grid gives a zero matrix.
"""
function dipole_interactions_full_integral_vectorized(
    state_N0_arr_transposed::Matrix{Float64}, # phi_beta
    state_N0p_arr_transposed::Matrix{Float64}, # phi_alpha'
    state_N1_arr_transposed::Matrix{Float64}, # tilde phi_alpha
    state_N1p_arr_transposed::Matrix{Float64}, # tilde phi_beta'
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)

    interaction = zeros(
        size(state_N0_arr_transposed, 1) *
        size(state_N1p_arr_transposed, 1),
        size(state_N0p_arr_transposed, 1) *
        size(state_N1_arr_transposed, 1)
    )
    site_term = similar(interaction) # scratch buffer, fully overwritten by kron!

    # Accumulate in place. Folding with `.+` over a closure that returns the
    # shared scratch buffer aliases the accumulator with that buffer, which drops
    # the first grid point and counts the second one twice.
    for (ii, mol1_position_vector) in enumerate(idx_to_position)
        # beta (state_N0) first dimension, alpha (state_N1) second dimension
        @views overlap = state_N0_arr_transposed[:,ii] .* state_N1_arr_transposed[:,ii]'
        # beta' (state_N1p) first dimension, alpha' (state_N0p) second dimension
        inner_integral = _dipole_interactions_inner_integral_vectorized(
            state_N0p_arr_transposed, state_N1p_arr_transposed,
            tweezer_separation_vector, polarization_vector, mol1_position_vector,
            idx_to_position
        )

        kron!(site_term, overlap, inner_integral)
        interaction .+= site_term
    end

    # 1/3 for plus-minus Hamiltonian
    interaction .*= 1/(2pi*hbar) * (1/3*d_NaCs^2)/(4pi*eps_0)

    if return_dicts
        idx_to_state = kron(string.(axes(state_N0_arr_transposed,1)), ["_"], string.(axes(state_N1p_arr_transposed,1)))

        return interaction, idx_to_state
    end

    return interaction
end


"""
    compute_diagonal_hamiltonians(energies_N0_left, energies_N1M_left,
                                  energies_N0_right, energies_N1M_right,
                                  level_arr; return_dicts=false)

Return the four diagonal energy matrices `Duu, Dud, Ddu, Ddd` (↑↑, ↑↓, ↓↑, ↓↓) of
two molecules in separate traps.

Each diagonal entry is the sum of one left-molecule energy and one
right-molecule energy, restricted to the levels in `level_arr`. `u` takes the
energy from the `N1M` vector and `d` from the `N0` vector; the first letter refers
to the left molecule. Along the diagonal the right-molecule level varies fastest.

With `return_dicts=true` a fifth value is returned: the vector of labels
"l_r" (left level, right level) in the same order as the diagonal.
"""
function compute_diagonal_hamiltonians(
    energies_N0_left::Vector{Float64},
    energies_N1M_left::Vector{Float64},
    energies_N0_right::Vector{Float64},
    energies_N1M_right::Vector{Float64},
    level_arr::Vector{Int64};
    return_dicts=false
)
    down_left = energies_N0_left[level_arr]
    up_left = energies_N1M_left[level_arr]
    down_right = energies_N0_right[level_arr]
    up_right = energies_N1M_right[level_arr]

    # Entry [r, l] of the broadcast is left[l] + right[r]; flattening it column
    # by column makes the right-molecule level the fast index.
    pair_energies(left, right) = diagm(0 => vec(left' .+ right))

    Duu = pair_energies(up_left, up_right)
    Dud = pair_energies(up_left, down_right)
    Ddu = pair_energies(down_left, up_right)
    Ddd = pair_energies(down_left, down_right)

    return_dicts || return Duu, Dud, Ddu, Ddd

    idx_to_state = vec([string(left, "_", right) for right in level_arr, left in level_arr])
    return Duu, Dud, Ddu, Ddd, idx_to_state
end


"""
    compute_diagonal_hamiltonians(energies_N0, energies_N1M, level_arr;
                                  return_dicts=false)

Return the four diagonal energy matrices `Duu, Dud, Ddu, Ddd` (↑↑, ↑↓, ↓↑, ↓↓) of
two molecules that share the same single-molecule energies.

Each diagonal entry is the sum of the energy of the first molecule and the energy
of the second molecule, restricted to the levels in `level_arr`. `u` takes the
energy from `energies_N1M` and `d` from `energies_N0`; the first letter refers to
the first molecule. Along the diagonal the level of the second molecule varies
fastest.

With `return_dicts=true` a fifth value is returned: the vector of labels
"a_b" (first level, second level) in the same order as the diagonal.
"""
function compute_diagonal_hamiltonians(
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    level_arr::Vector{Int64};
    return_dicts=false
)
    down = energies_N0[level_arr]
    up = energies_N1M[level_arr]

    # Entry [b, a] of the broadcast is first[a] + second[b]; flattening it column
    # by column makes the level of the second molecule the fast index.
    pair_energies(first_mol, second_mol) = diagm(0 => vec(first_mol' .+ second_mol))

    Duu = pair_energies(up, up)
    Dud = pair_energies(up, down)
    Ddu = pair_energies(down, up)
    Ddd = pair_energies(down, down)

    return_dicts || return Duu, Dud, Ddu, Ddd

    idx_to_state = vec([string(a, "_", b) for b in level_arr, a in level_arr])
    return Duu, Dud, Ddu, Ddd, idx_to_state
end


"""
    single_molecule_drive(Omega, states_N0, states_N1M)

Return the drive coupling block `Omega/2 * states_N1M' * states_N0`.

Entry `[a, b]` is `Omega/2` times the inner product of column `a` of `states_N1M`
with column `b` of `states_N0`.
"""
function single_molecule_drive(
    Omega::Float64,
    states_N0::Matrix{Float64},
    states_N1M::Matrix{Float64}
)
    half_rabi = Omega/2
    motional_overlap = adjoint(states_N1M) * states_N0

    return half_rabi * motional_overlap
end


"""
    compute_twobody_hamiltonian(states_N0_left, states_N1M_left,
                                energies_N0_left, energies_N1M_left,
                                states_N0_right, states_N1M_right,
                                energies_N0_right, energies_N1M_right,
                                num_lvls, Omega, dE,
                                tweezer_separation_vector, polarization_vector,
                                idx_to_position; return_dicts=false)

Assemble the two-molecule Hamiltonian for molecules in two different traps,
keeping the first `num_lvls` motional levels (columns of the state matrices) of
each molecule.

The matrix is built from 4 x 4 blocks in the order uu, du, ud, dd (first letter:
left molecule):
- diagonal blocks come from `compute_diagonal_hamiltonians`; the uu block is
  shifted by `-2dE` and the du and ud blocks by `-dE`;
- the du/ud off-diagonal block is the dipole coupling from
  `dipole_interactions_full_integral_vectorized`;
- blocks that differ in the state of one molecule hold that molecule's
  `single_molecule_drive` matrix, expanded with an identity on the other molecule;
- uu and dd are not coupled directly.

With `return_dicts=true` returns `(H, labels)`, where `labels` are strings such as
"uu_1_2" in the order of the rows of `H`; otherwise returns `H`.
"""
function compute_twobody_hamiltonian(
    states_N0_left::Matrix{Float64}, # phi_beta
    states_N1M_left::Matrix{Float64}, # tilde phi_alpha
    energies_N0_left::Vector{Float64},
    energies_N1M_left::Vector{Float64},
    states_N0_right::Matrix{Float64}, # phi_beta
    states_N1M_right::Matrix{Float64}, # tilde phi_alpha
    energies_N0_right::Vector{Float64},
    energies_N1M_right::Vector{Float64},
    num_lvls::Int64,
    Omega::Float64,
    dE::Float64,
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)
    lvls = 1:num_lvls
    n0_left = states_N0_left[:, lvls]
    n1_left = states_N1M_left[:, lvls]
    n0_right = states_N0_right[:, lvls]
    n1_right = states_N1M_right[:, lvls]

    # The vectorized integral expects one level per row.
    dd_coupling = dipole_interactions_full_integral_vectorized(
        permutedims(n0_left),
        permutedims(n0_right),
        permutedims(n1_left),
        permutedims(n1_right),
        tweezer_separation_vector,
        polarization_vector,
        idx_to_position
    )

    drive_left = single_molecule_drive(Omega, n0_left, n1_left)
    drive_right = single_molecule_drive(Omega, n0_right, n1_right)

    # Left molecule is the slow index of the pair basis, right molecule the fast one.
    H_dr_left = kron(drive_left, I(size(drive_right, 1)))
    H_dr_right = kron(I(size(drive_left, 1)), drive_right)

    Duu, Dud, Ddu, Ddd, pair_labels = compute_diagonal_hamiltonians(
        energies_N0_left,
        energies_N1M_left,
        energies_N0_right,
        energies_N1M_right,
        collect(lvls);
        return_dicts=true
    )

    uu_block = Duu - 2dE*I
    du_block = Ddu - dE*I
    ud_block = Dud - dE*I
    no_coupling = zeros(size(Duu)...)

    H = [
        uu_block     H_dr_left     H_dr_right   no_coupling
        H_dr_left'   du_block      dd_coupling  H_dr_right
        H_dr_right'  dd_coupling'  ud_block     H_dr_left
        no_coupling  H_dr_right'   H_dr_left'   Ddd
    ]

    return return_dicts ? (H, kron(["uu_", "du_", "ud_", "dd_"], pair_labels)) : H
end


"""
    compute_twobody_hamiltonian(states_N0, states_N1M, energies_N0, energies_N1M,
                                num_lvls, Omega, dE,
                                tweezer_separation_vector, polarization_vector,
                                idx_to_position; return_dicts=false)

Assemble the two-molecule Hamiltonian when both molecules share the same
single-molecule states and energies.

This is the left/right method of `compute_twobody_hamiltonian` called with the
same states and energies for both molecules; it forwards to that method instead
of repeating its body, so the block layout (uu, du, ud, dd) and the meaning of
`return_dicts` are those of the left/right method.
"""
function compute_twobody_hamiltonian(
    states_N0::Matrix{Float64}, # phi_beta
    states_N1M::Matrix{Float64}, # tilde phi_alpha
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    num_lvls::Int64,
    Omega::Float64,
    dE::Float64,
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)
    return compute_twobody_hamiltonian(
        # left molecule
        states_N0, states_N1M, energies_N0, energies_N1M,
        # right molecule: identical to the left one
        states_N0, states_N1M, energies_N0, energies_N1M,
        num_lvls,
        Omega,
        dE,
        tweezer_separation_vector,
        polarization_vector,
        idx_to_position;
        return_dicts=return_dicts
    )
end


"""
    compute_twobody_hamiltonian(states_N0, states_N1M, energies_N0, energies_N1M,
                                num_lvls, Omega,
                                tweezer_separation_vector, polarization_vector,
                                idx_to_position; return_dicts=false)

Assemble the two-molecule Hamiltonian when both molecules share the same
single-molecule states and energies and the energy shift `dE` is taken as
`energies_N1M[1] - energies_N0[1]`.

Forwards to the left/right method of `compute_twobody_hamiltonian` with the same
states and energies for both molecules, instead of repeating its body. The block
layout (uu, du, ud, dd) and the meaning of `return_dicts` are those of that
method.
"""
function compute_twobody_hamiltonian(
    states_N0::Matrix{Float64}, # phi_beta
    states_N1M::Matrix{Float64}, # tilde phi_alpha
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    num_lvls::Int64,
    Omega::Float64,
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)
    # Energy shift: difference between the lowest N1M and N0 levels.
    dE = energies_N1M[1] - energies_N0[1]

    return compute_twobody_hamiltonian(
        # left molecule
        states_N0, states_N1M, energies_N0, energies_N1M,
        # right molecule: identical to the left one
        states_N0, states_N1M, energies_N0, energies_N1M,
        num_lvls,
        Omega,
        dE,
        tweezer_separation_vector,
        polarization_vector,
        idx_to_position;
        return_dicts=return_dicts
    )
end


"""
    compute_pulse_phase_twobody_hamiltonian(states_N0_left, states_N1M_left,
                                            energies_N0_left, energies_N1M_left,
                                            states_N0_right, states_N1M_right,
                                            energies_N0_right, energies_N1M_right,
                                            num_lvls, Omega, dE, pulse_phase,
                                            tweezer_separation_vector,
                                            polarization_vector,
                                            idx_to_position; return_dicts=false)

Assemble the complex two-molecule Hamiltonian for molecules in two different
traps, with the drive multiplied by the phase factor `exp(im*pulse_phase)`.

The first `num_lvls` motional levels (columns of the state matrices) of each
molecule are kept. The matrix is built from 4 x 4 blocks in the order
uu, du, ud, dd (first letter: left molecule):
- diagonal blocks come from `compute_diagonal_hamiltonians`; the uu block is
  shifted by `-2dE` and the du and ud blocks by `-dE`;
- the du/ud off-diagonal block is the dipole coupling from
  `dipole_interactions_full_integral_vectorized`;
- blocks that differ in the state of one molecule hold that molecule's
  phase-multiplied `single_molecule_drive` matrix (or its adjoint), expanded with
  an identity on the other molecule;
- uu and dd are not coupled directly.

With `return_dicts=true` returns `(H, labels)`, where `labels` are strings such as
"uu_1_2" in the order of the rows of `H`; otherwise returns `H`.
"""
function compute_pulse_phase_twobody_hamiltonian(
    states_N0_left::Matrix{Float64}, # phi_beta
    states_N1M_left::Matrix{Float64}, # tilde phi_alpha
    energies_N0_left::Vector{Float64},
    energies_N1M_left::Vector{Float64},
    states_N0_right::Matrix{Float64}, # phi_beta
    states_N1M_right::Matrix{Float64}, # tilde phi_alpha
    energies_N0_right::Vector{Float64},
    energies_N1M_right::Vector{Float64},
    num_lvls::Int64,
    Omega::Float64,
    dE::Float64,
    pulse_phase::Float64,
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)
    # The vectorized integral expects one level per row, hence the transposes.
    H_dd_cpl = dipole_interactions_full_integral_vectorized(
        collect(states_N0_left[:,1:num_lvls]'),
        collect(states_N0_right[:,1:num_lvls]'),
        collect(states_N1M_left[:,1:num_lvls]'),
        collect(states_N1M_right[:,1:num_lvls]'),
        tweezer_separation_vector,
        polarization_vector,
        idx_to_position
    )

    H_dr_1body_left = single_molecule_drive(
        Omega,
        states_N0_left[:,1:num_lvls],
        states_N1M_left[:,1:num_lvls],
    )

    H_dr_1body_right = single_molecule_drive(
        Omega,
        states_N0_right[:,1:num_lvls],
        states_N1M_right[:,1:num_lvls],
    )

    # Left molecule is the slow index of the pair basis, right molecule the fast
    # one. The adjoints used below carry the conjugate phase.
    e_iphase = exp(im*pulse_phase)
    H_dr_left = kron(e_iphase*H_dr_1body_left, I(size(H_dr_1body_right,1)))
    H_dr_right = kron(I(size(H_dr_1body_left,1)), e_iphase*H_dr_1body_right)

    Duu, Dud, Ddu, Ddd, idx_to_state_1 = compute_diagonal_hamiltonians(
        energies_N0_left,
        energies_N1M_left,
        energies_N0_right,
        energies_N1M_right,
        collect(1:num_lvls),
        return_dicts=true
    )

    H = [Duu-2dE*I                 H_dr_left       H_dr_right     zeros(size(Duu)...);
         H_dr_left'                Ddu-dE*I        H_dd_cpl       H_dr_right;
         H_dr_right'               H_dd_cpl'       Dud-dE*I       H_dr_left;
         zeros(size(Duu)...)       H_dr_right'     H_dr_left'     Ddd]

    if return_dicts
        # Kept as a debug-level log record instead of an unconditional print.
        @debug "two-body motional state labels" idx_to_state_1
        total_dicts = kron(["uu_", "du_", "ud_", "dd_"], idx_to_state_1)
        return H, total_dicts
    end

    return H
end


"""
    compute_pulse_phase_twobody_hamiltonian(states_N0, states_N1M,
                                            energies_N0, energies_N1M,
                                            num_lvls, Omega, pulse_phase,
                                            tweezer_separation_vector,
                                            polarization_vector,
                                            idx_to_position; return_dicts=false)

Assemble the phase-dependent two-molecule Hamiltonian when both molecules share
the same single-molecule states and energies and the energy shift `dE` is taken
as `energies_N1M[1] - energies_N0[1]`.

Forwards to the left/right method of `compute_pulse_phase_twobody_hamiltonian`
with the same states and energies for both molecules, instead of repeating its
body. The block layout (uu, du, ud, dd) and the meaning of `return_dicts` are
those of that method.
"""
function compute_pulse_phase_twobody_hamiltonian(
    states_N0::Matrix{Float64}, # phi_beta
    states_N1M::Matrix{Float64}, # tilde phi_alpha
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    num_lvls::Int64,
    Omega::Float64,
    pulse_phase::Float64,
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)
    # Energy shift: difference between the lowest N1M and N0 levels.
    dE = energies_N1M[1] - energies_N0[1]

    return compute_pulse_phase_twobody_hamiltonian(
        # left molecule
        states_N0, states_N1M, energies_N0, energies_N1M,
        # right molecule: identical to the left one
        states_N0, states_N1M, energies_N0, energies_N1M,
        num_lvls,
        Omega,
        dE,
        pulse_phase,
        tweezer_separation_vector,
        polarization_vector,
        idx_to_position;
        return_dicts=return_dicts
    )
end


"""
    compute_pulse_phase_twobody_hamiltonian(states_N0, states_N1M,
                                            energies_N0, energies_N1M,
                                            num_lvls, Omega, dE, pulse_phase,
                                            tweezer_separation_vector,
                                            polarization_vector,
                                            idx_to_position; return_dicts=false)

Assemble the phase-dependent two-molecule Hamiltonian when both molecules share
the same single-molecule states and energies, with an explicit energy shift `dE`.

Forwards to the left/right method of `compute_pulse_phase_twobody_hamiltonian`
with the same states and energies for both molecules, instead of repeating its
body. The block layout (uu, du, ud, dd) and the meaning of `return_dicts` are
those of that method.
"""
function compute_pulse_phase_twobody_hamiltonian(
    states_N0::Matrix{Float64}, # phi_beta
    states_N1M::Matrix{Float64}, # tilde phi_alpha
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    num_lvls::Int64,
    Omega::Float64,
    dE::Float64,
    pulse_phase::Float64,
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)
    return compute_pulse_phase_twobody_hamiltonian(
        # left molecule
        states_N0, states_N1M, energies_N0, energies_N1M,
        # right molecule: identical to the left one
        states_N0, states_N1M, energies_N0, energies_N1M,
        num_lvls,
        Omega,
        dE,
        pulse_phase,
        tweezer_separation_vector,
        polarization_vector,
        idx_to_position;
        return_dicts=return_dicts
    )
end


"""
    single_molecule_hamiltonian(energies_N0, energies_N1M, Omega, states_N0, states_N1M)

Build the single-molecule Hamiltonian with zero drive phase.

Forwards to the method that takes a phase `phi`, passing `phi = 0.0`.
"""
function single_molecule_hamiltonian(
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    Omega::Float64,
    states_N0::Matrix{Float64},
    states_N1M::Matrix{Float64}
)
    zero_phase = 0.0

    return single_molecule_hamiltonian(
        energies_N0, energies_N1M, Omega, zero_phase, states_N0, states_N1M
    )
end


"""
    single_molecule_hamiltonian(energies_N0, energies_N1M, Omega, phi,
                                states_N0, states_N1M)

Build the single-molecule Hamiltonian as a 2 x 2 block matrix.

The diagonal blocks are `diagm(energies_N0)` and `diagm(energies_N1M)`. The upper
right block is `Omega*exp(im*phi)/2 * states_N0' * states_N1M` and the lower left
block is `Omega*exp(-im*phi)/2 * states_N1M' * states_N0`, so the two
off-diagonal blocks are adjoints of each other. The result is complex.
"""
function single_molecule_hamiltonian(
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    Omega::Float64,
    phi::Float64,
    states_N0::Matrix{Float64},
    states_N1M::Matrix{Float64}
)
    # Complex drive amplitudes for the two off-diagonal blocks.
    amp_upper = Omega*exp(im*phi)/2
    amp_lower = Omega*exp(-im*phi)/2

    upper_right = amp_upper*(states_N0' * states_N1M)
    lower_left = amp_lower*(states_N1M' * states_N0)

    upper_left = diagm(0 => energies_N0)
    lower_right = diagm(0 => energies_N1M)

    return [
        upper_left  upper_right
        lower_left  lower_right
    ]
end


"""
    single_molecule_hamiltonian(energies_N0, energies_N1M, Omega, Delta, phi,
                                states_N0, states_N1M)

Build the single-molecule Hamiltonian as a 2 x 2 block matrix, with the `N1M`
energies shifted by `-Delta`.

The diagonal blocks are `diagm(energies_N0)` and `diagm(energies_N1M) - Delta*I`.
The upper right block is `Omega*exp(im*phi)/2 * states_N0' * states_N1M` and the
lower left block is `Omega*exp(-im*phi)/2 * states_N1M' * states_N0`, so the two
off-diagonal blocks are adjoints of each other. The result is complex.
"""
function single_molecule_hamiltonian(
    energies_N0::Vector{Float64},
    energies_N1M::Vector{Float64},
    Omega::Float64,
    Delta::Float64,
    phi::Float64,
    states_N0::Matrix{Float64},
    states_N1M::Matrix{Float64}
)
    # Complex drive amplitudes for the two off-diagonal blocks.
    amp_upper = Omega*exp(im*phi)/2
    amp_lower = Omega*exp(-im*phi)/2

    upper_right = amp_upper*(states_N0' * states_N1M)
    lower_left = amp_lower*(states_N1M' * states_N0)

    upper_left = diagm(0 => energies_N0)
    lower_right = diagm(0 => energies_N1M) - Delta*I

    return [
        upper_left  upper_right
        lower_left  lower_right
    ]
end

"""
    compute_pulse_phase_twobody_hamiltonian(states_N0_left, states_N1M_left,
                                            energies_N0_left, energies_N1M_left,
                                            states_N0_right, states_N1M_right,
                                            energies_N0_right, energies_N1M_right,
                                            num_lvls_max, sub_manifold_arr,
                                            Omega, dE, pulse_phase,
                                            tweezer_separation_vector,
                                            polarization_vector,
                                            idx_to_position; return_dicts=false)

Build the full phase-dependent two-molecule Hamiltonian with `num_lvls_max`
levels per molecule and cut out one sub-matrix per entry of `sub_manifold_arr`.

`sub_manifold_arr` is a vector of manifolds; each manifold is a vector of
`(site1, site2)` tuples giving the motional level of the left and of the right
molecule. For every manifold the rows and columns labelled "uu_", "ud_", "du_"
and "dd_" followed by "site1_site2" are selected, in that order, from the full
Hamiltonian. A tuple that has no matching label in the full Hamiltonian raises a
`KeyError`.

Returns the vector of sub-matrices; with `return_dicts=true` returns
`(sub_matrices, labels)`, where `labels[k]` lists the row labels of sub-matrix `k`.
"""
function compute_pulse_phase_twobody_hamiltonian(
    states_N0_left::Matrix{Float64}, # phi_beta
    states_N1M_left::Matrix{Float64}, # tilde phi_alpha
    energies_N0_left::Vector{Float64},
    energies_N1M_left::Vector{Float64},
    states_N0_right::Matrix{Float64}, # phi_beta
    states_N1M_right::Matrix{Float64}, # tilde phi_alpha
    energies_N0_right::Vector{Float64},
    energies_N1M_right::Vector{Float64},
    num_lvls_max::Int64,
    sub_manifold_arr::Vector{Vector{Tuple{Int64,Int64}}},
    Omega::Float64,
    dE::Float64,
    pulse_phase::Float64,
    tweezer_separation_vector::Vector{Float64},
    polarization_vector::Vector{Float64},
    idx_to_position::Vector{Vector{Float64}};
    return_dicts::Bool=false
)
    # Full Hamiltonian and its row labels; the labels are always needed here to
    # locate the rows of each manifold.
    full_H, full_labels = compute_pulse_phase_twobody_hamiltonian(
        states_N0_left,
        states_N1M_left,
        energies_N0_left,
        energies_N1M_left,
        states_N0_right,
        states_N1M_right,
        energies_N0_right,
        energies_N1M_right,
        num_lvls_max,
        Omega,
        dE,
        pulse_phase,
        tweezer_separation_vector,
        polarization_vector,
        idx_to_position;
        return_dicts=true
    )
    label_to_row = Dict(zip(full_labels, eachindex(full_labels)))

    spin_prefixes = ["uu_", "ud_", "du_", "dd_"]
    sub_hamiltonians = Matrix{ComplexF64}[]
    sub_labels = Vector{String}[]

    for manifold in sub_manifold_arr
        # All pairs of the manifold for the first prefix, then for the next, ...
        labels = reduce(
            vcat,
            [[string(prefix, site1, "_", site2) for (site1, site2) in manifold]
             for prefix in spin_prefixes]
        )
        rows = [label_to_row[label] for label in labels]

        push!(sub_hamiltonians, full_H[rows, rows])
        push!(sub_labels, labels)
    end

    return return_dicts ? (sub_hamiltonians, sub_labels) : sub_hamiltonians
end
